# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Low-level APIs for algorithms to communicate with NNI manager.
"""

from __future__ import annotations

__all__ = ['TunerCommandChannel']

import logging
import os
from collections import defaultdict
from threading import Event
from typing import Any, Callable

from nni.common.serializer import dump, load, PayloadTooLarge
from nni.runtime.command_channel.websocket import WsChannelClient
from nni.typehint import Parameters

from .command_type import (
    CommandType, TunerIncomingCommand,
    Initialize, RequestTrialJobs, UpdateSearchSpace, ReportMetricData, TrialEnd, Terminate
)

_logger = logging.getLogger(__name__)

TunerCommandCallback = Callable[[TunerIncomingCommand], None]

class TunerCommandChannel:
    """
    A channel to communicate with NNI manager.

    Each NNI experiment has a channel URL for tuner/assessor/strategy algorithm.
    The channel can only be connected once, so for each Python side :class:`~nni.experiment.Experiment` object,
    there should be exactly one corresponding ``TunerCommandChannel`` instance.

    :meth:`connect` must be invoked before sending or receiving data.

    The constructor has no side effects, so a ``TunerCommandChannel`` can be created anywhere.
    But :meth:`connect` requires an initialized NNI manager; otherwise the behavior is unpredictable.

    :meth:`_send` and :meth:`_receive` are underscore-prefixed because their signatures are scheduled to change by v3.0.

    Parameters
    ----------
    url
        The command channel URL.
        For now it must be like ``"ws://localhost:8080/tuner"`` or ``"ws://localhost:8080/url-prefix/tuner"``.
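
    Examples
    --------
    A minimal sketch of setting up the channel; the URL here is made up::

        channel = TunerCommandChannel('ws://localhost:8080/tuner')
        channel.connect()
        channel.send_initialized()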
    """

    def __init__(self, url: str):
        self._channel = WsChannelClient(url)
        self._callbacks: dict[CommandType, list[Callable[..., None]]] = defaultdict(list)

    def connect(self) -> None:
        """Connect to NNI manager. Must be called before sending or receiving data."""
        self._channel.connect()

    def disconnect(self) -> None:
        """Disconnect from NNI manager."""
        self._channel.disconnect()

    def listen(self, stop_event: Event) -> None:
        """Listen for incoming commands.

        Call :meth:`receive` in a loop and call ``callback`` for each command,
        until ``stop_event`` is set, or a Terminate command is received.
        All commands will go into callback, including Terminate command.

        It usually runs in a separate thread.

        Parameters
        ----------
        callback
            A callback function that takes a :class:`TunerIncomingCommand` as argument.
            It's not expected to return anything.
        stop_event
            A threading event that can be used to stop the loop.
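
        Examples
        --------
        A sketch of running the listener in a background thread;
        ``channel`` is assumed to be a connected instance::

            from threading import Event, Thread

            stop_event = Event()
            channel.on_request_trial_jobs(lambda command: print(command))
            Thread(target=channel.listen, args=(stop_event,)).start()
            # From another thread, set the event to stop listening:
            stop_event.set()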
        """
        while not stop_event.is_set():
            received = self.receive()
            for callback in self._callbacks[received.command_type]:
                callback(received)

            # Two ways to stop the loop:
            # 1. The received command is a Terminate command, which is triggered by an NNI manager stop.
            # 2. The stop_event is set from another thread (possibly main thread), which could be an engine shutdown.
            if received.command_type == CommandType.Terminate:
                _logger.debug('Received command type is terminate. Stop listening.')
                stop_event.set()

    # NOTE: Only a subset of semantic commands is implemented, for the convenience of NAS.
    # Send commands are split into separate functions with dedicated signatures.
    # Ideally receive commands would be handled similarly, but we can't predict which command will arrive.

    def send_initialized(self) -> None:
        """Send an initialized command to NNI manager."""
        self._send(CommandType.Initialized, '')

    def send_trial(
        self,
        parameter_id: int,
        parameters: Parameters,
        parameter_source: str = 'algorithm',
        parameter_index: int = 0,
        placement_constraint: dict[str, Any] | None = None,  # TODO: Define PlacementConstraint class.
    ):
        """
        Send a new trial job to NNI manager.

        With multi-phase out of the picture, one set of parameters corresponds to exactly one trial.

        Parameters
        ----------
        parameter_id
            The ID of the current parameter.
            It's used by whoever calls the :meth:`send_trial` function to identify the parameters.
            In most cases, they are non-negative integers starting from 0.
        parameters
            The parameters.
        parameter_source
            The source of the parameters. ``algorithm`` means the parameters are generated by the algorithm.
            It should be left as default in most cases.
        parameter_index
            The index of the parameters. This is previously used in multi-phase, but now it's only kept for compatibility reasons.
        placement_constraint
            The placement constraint of the created trial job.
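
        Examples
        --------
        A minimal sketch; the parameter values here are made up::

            channel.send_trial(
                parameter_id=0,
                parameters={'learning_rate': 0.1},
            )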
        """
        # Local import to reduce import delay.
        from nni.common.version import version_dump

        trial_dict = {
            'parameter_id': parameter_id,
            'parameters': parameters,
            'parameter_source': parameter_source,
            'parameter_index': parameter_index,
            'version_info': version_dump()
        }
        if placement_constraint is not None:
            _validate_placement_constraint(placement_constraint)
            trial_dict['placement_constraint'] = placement_constraint

        try:
            send_payload = dump(trial_dict, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
        except PayloadTooLarge:
            raise ValueError(
                'Serialization failed when trying to dump the model because the payload is too large '
                '(by default, larger than 64 KB). '
                'This is usually caused by pickling large objects (like datasets) by mistake. '
                'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
                'for how to resolve this issue.'
            )

        self._send(CommandType.NewTrialJob, send_payload)

    def send_no_more_trial_jobs(self) -> None:
        """Tell NNI manager that there are no more trial jobs to send for now."""
        self._send(CommandType.NoMoreTrialJobs, '')

    def receive(self) -> TunerIncomingCommand:
        """Receives a command from NNI manager."""
        command_type, data = self._receive()
        if data:
            data = load(data)

        # NOTE: Only handles the commands that are used by NAS.
        # It uses a somewhat hacky way to convert the data received from NNI manager
        # to a semantic command.
        if command_type is None:
            # This shouldn't happen. Only for robustness.
            _logger.warning('Received command is empty. Terminating...')
            return Terminate()
        elif command_type == CommandType.Terminate:
            return Terminate()
        elif command_type == CommandType.Initialize:
            if not isinstance(data, dict):
                raise TypeError(f'Initialize command data must be a dict, but got {type(data)}')
            return Initialize(data)
        elif command_type == CommandType.RequestTrialJobs:
            if not isinstance(data, int):
                raise TypeError(f'RequestTrialJobs command data must be an integer, but got {type(data)}')
            return RequestTrialJobs(data)
        elif command_type == CommandType.UpdateSearchSpace:
            if not isinstance(data, dict):
                raise TypeError(f'UpdateSearchSpace command data must be a dict, but got {type(data)}')
            return UpdateSearchSpace(data)
        elif command_type == CommandType.ReportMetricData:
            if not isinstance(data, dict):
                raise TypeError(f'ReportMetricData command data must be a dict, but got {type(data)}')
            if 'value' in data:
                data['value'] = load(data['value'])
            return ReportMetricData(**data)
        elif command_type == CommandType.TrialEnd:
            if not isinstance(data, dict):
                raise TypeError(f'TrialEnd command data must be a dict, but got {type(data)}')
            # For some reason, only one parameter (presumably the first one) shows up in the data,
            # even though a trial is technically associated with multiple parameters.
            parameter_id = load(data['hyper_params'])['parameter_id']
            return TrialEnd(
                trial_job_id=data['trial_job_id'],
                parameter_ids=[parameter_id],
                event=data['event']
            )
        else:
            raise ValueError(f'Unknown command type: {command_type}')

    def on_terminate(self, callback: Callable[[Terminate], None]) -> None:
        """Register a callback for Terminate command.

        Parameters
        ----------
        callback
            A callback function that takes a :class:`Terminate` as argument.
        """
        self._callbacks[Terminate.command_type].append(callback)

    def on_initialize(self, callback: Callable[[Initialize], None]) -> None:
        """Register a callback for Initialize command.

        Parameters
        ----------
        callback
            A callback function that takes a :class:`Initialize` as argument.
        """
        self._callbacks[Initialize.command_type].append(callback)

    def on_request_trial_jobs(self, callback: Callable[[RequestTrialJobs], None]) -> None:
        """Register a callback for RequestTrialJobs command.

        Parameters
        ----------
        callback
            A callback function that takes a :class:`RequestTrialJobs` as argument.
        """
        self._callbacks[RequestTrialJobs.command_type].append(callback)

    def on_update_search_space(self, callback: Callable[[UpdateSearchSpace], None]) -> None:
        """Register a callback for UpdateSearchSpace command.

        Parameters
        ----------
        callback
            A callback function that takes a :class:`UpdateSearchSpace` as argument.
        """
        self._callbacks[UpdateSearchSpace.command_type].append(callback)

    def on_report_metric_data(self, callback: Callable[[ReportMetricData], None]) -> None:
        """Register a callback for ReportMetricData command.

        Parameters
        ----------
        callback
            A callback function that takes a :class:`ReportMetricData` as argument.
        """
        self._callbacks[ReportMetricData.command_type].append(callback)

    def on_trial_end(self, callback: Callable[[TrialEnd], None]) -> None:
        """Register a callback for TrialEnd command.

        Parameters
        ----------
        callback
            A callback function that takes a :class:`TrialEnd` as argument.
        """
        self._callbacks[TrialEnd.command_type].append(callback)

    def _send(self, command_type: CommandType, data: str) -> None:
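        # Wire format: a dict of {'type': <CommandType value>, 'content': <payload string>}.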
        self._channel.send({'type': command_type.value, 'content': data})

    def _receive(self) -> tuple[CommandType, str] | tuple[None, None]:
        command = self._channel.receive()
        if command is None:
            return None, None
        else:
            return CommandType(command['type']), command.get('content', '')


def _validate_placement_constraint(placement_constraint):
    # Currently only for CGO.
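    # Illustrative examples of constraints that pass the checks below (values made up):
    #   {'type': 'None', 'gpus': []}
    #   {'type': 'GPUNumber', 'gpus': [2]}
    #   {'type': 'Device', 'gpus': [('hostname', 0)]}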
    if placement_constraint is None:
        raise ValueError('placement_constraint is None')
    if 'type' not in placement_constraint:
        raise ValueError('placement_constraint must have `type`')
    if 'gpus' not in placement_constraint:
        raise ValueError('placement_constraint must have `gpus`')
    if placement_constraint['type'] not in ['None', 'GPUNumber', 'Device']:
        raise ValueError('placement_constraint.type must be one of `None`, `GPUNumber`, or `Device`')
    if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
        raise ValueError('placement_constraint.gpus must be an empty list when type == None')
    if placement_constraint['type'] == 'GPUNumber':
        if len(placement_constraint['gpus']) != 1:
            raise ValueError('placement_constraint.gpus currently only supports one host when type == GPUNumber')
        for e in placement_constraint['gpus']:
            if not isinstance(e, int):
                raise ValueError('placement_constraint.gpus must be a list of integers when type == GPUNumber')
    if placement_constraint['type'] == 'Device':
        for e in placement_constraint['gpus']:
            if not isinstance(e, tuple):
                raise ValueError('placement_constraint.gpus must be a list of tuples when type == Device')
            if not (len(e) == 2 and isinstance(e[0], str) and isinstance(e[1], int)):
                raise ValueError('each tuple in placement_constraint.gpus must be (str, int)')
